from transformers import AutoTokenizer, AutoModelForCausalLM
from argparse import ArgumentParser
import torch
from torch.utils.data import DataLoader
from tools import read_jsonl, DynamicDataset, collate_fn, read_test_jsonl
from tqdm import tqdm
import jsonlines
import pdb


def hyperparameters(argv=None):
    """Parse the command-line options for the ORCA-2 rule-probing run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the original behaviour
            of reading ``sys.argv[1:]``; passing a list makes the function
            usable from tests and other Python code.

    Returns:
        argparse.Namespace with the model path, input/output locations and
        the generation settings (batch size, new-token budget, temperature).
    """
    parser = ArgumentParser(description="Run ORCA2 for Rule Probing")

    # Locations of the model, the input data and the output file.
    parser.add_argument("--model", type=str, default="", help="Model path")
    parser.add_argument(
        "--data_dir", type=str, default="ecare",
        help="Sub-directory of ./data that holds the input file",
    )
    parser.add_argument(
        "--data_file", type=str, default="final_data_test",
        help="Input file name inside ./data/<data_dir>/, without the .jsonl extension",
    )
    parser.add_argument(
        "--output_dir", type=str, default="output",
        help="Directory (relative to the working directory) for the output file",
    )
    parser.add_argument(
        "--output_file", type=str, default="orca-2-13b-v3.jsonl",
        help="Name of the JSONL file the answers are written to",
    )

    # Generation settings.
    parser.add_argument(
        "--batch_size", type=int, default=72,
        help="Number of prompts generated per batch",
    )
    parser.add_argument(
        "--max_length", type=int, default=128,
        help="Maximum number of new tokens generated per prompt",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.00001,
        help="Sampling temperature; the tiny default makes sampling near-greedy",
    )

    return parser.parse_args(argv)


# System prompt placed at the start of every ORCA-2 conversation.
SYSTEM_MESSAGE = "You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."


def build_prompt(user_message, system_message=SYSTEM_MESSAGE):
    """Wrap ``user_message`` in the ChatML-style template used for ORCA-2.

    The prompt ends with an open ``assistant`` turn, so the model's
    continuation is its reply.

    Args:
        user_message: Text of the user turn.
        system_message: Text of the system turn; defaults to ``SYSTEM_MESSAGE``.

    Returns:
        The full prompt string.
    """
    return (
        f"<|im_start|>system\n{system_message}<|im_end|>\n"
        f"<|im_start|>user\n{user_message}<|im_end|>\n"
        "<|im_start|>assistant"
    )


def main(args):
    """Generate ORCA-2 answers for every example and write them as JSONL.

    Reads ``./data/<data_dir>/<data_file>.jsonl`` through ``read_test_jsonl``.
    Writes one ``{"input": <user message>, "R": <decoded output>}`` record per
    example to ``./<output_dir>/<output_file>``.

    Args:
        args: Namespace produced by ``hyperparameters()``.
    """
    import os

    print(args)

    # Load pre-trained weights and tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # NOTE(review): pad id 0 is forced, presumably because the tokenizer has
    # no pad token of its own -- confirm id 0 is not a real content token.
    tokenizer.pad_token_id = 0
    # Batched generation with a decoder-only model needs left padding so that
    # every prompt ends exactly where the new tokens start.
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(args.model, device_map="auto")
    model.eval()

    # Load data; shuffle=False keeps output records in input order.
    data = read_test_jsonl(f"./data/{args.data_dir}/{args.data_file}.jsonl")
    dataset = DynamicDataset(*data)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # Create the output directory up front rather than failing after the
    # (slow) model load.
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = f"./{args.output_dir}/{args.output_file}"

    # The context manager guarantees the writer is flushed and closed even if
    # generation raises part-way through.
    with jsonlines.open(output_path, "w") as fo, torch.no_grad():
        for user_messages, _ in tqdm(dataloader):
            prompts = [build_prompt(message) for message in user_messages]

            encoded = tokenizer(prompts, return_tensors="pt", padding=True)
            # Send inputs to the model's device instead of hard-coding CUDA;
            # this respects whatever placement device_map="auto" chose.
            output_ids = model.generate(
                encoded.input_ids.to(model.device),
                attention_mask=encoded.attention_mask.to(model.device),
                # Sampling with a near-zero temperature is effectively greedy.
                temperature=args.temperature,
                max_new_tokens=args.max_length,
                do_sample=True,
                top_k=50,
            )
            # NOTE(review): for decoder-only models `generate` returns the
            # prompt ids followed by the new tokens, so each decoded answer
            # also contains the prompt text. This is kept unchanged in case
            # downstream parsing relies on it.
            answers = tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

            for user_message, answer in zip(user_messages, answers):
                fo.write({"input": user_message, "R": answer})


if __name__ == "__main__":
    main(hyperparameters())













